{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook gathers the code samples from the video below, which is part of the [Hugging Face course](https://huggingface.co/course)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/qgaM0weJHpA?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "# Embed the course video that this notebook accompanies.\n",
    "from IPython.display import HTML\n",
    "\n",
    "# The HTML object is the cell's last expression, so the notebook renders the player.\n",
    "HTML(data='<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/qgaM0weJHpA?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers and Datasets libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel, which `! pip` does not guarantee.\n",
    "# The extra is quoted so a shell never treats the square brackets as a glob pattern.\n",
    "%pip install datasets \"transformers[sentencepiece]\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load SQuAD and drop the two columns that no later cell reads.\n",
    "raw_datasets = load_dataset(\"squad\").remove_columns([\"id\", \"title\"])\n",
    "\n",
    "\n",
    "def prepare_data(example):\n",
    "    \"\"\"Flatten the first answer of `example` into two character offsets.\n",
    "\n",
    "    Adds `answer_start` (start offset of the first answer) and `answer_end`\n",
    "    (that offset plus the length of the first answer text) to `example`,\n",
    "    then returns the same `example` object.\n",
    "    \"\"\"\n",
    "    answers = example[\"answers\"]\n",
    "    first_text = answers[\"text\"][0]\n",
    "    example[\"answer_start\"] = answers[\"answer_start\"][0]\n",
    "    example[\"answer_end\"] = example[\"answer_start\"] + len(first_text)\n",
    "    return example\n",
    "\n",
    "\n",
    "# The nested \"answers\" column is dropped once it has been flattened.\n",
    "raw_datasets = raw_datasets.map(prepare_data, remove_columns=[\"answers\"])\n",
    "raw_datasets[\"train\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look the first training example up once instead of re-reading the row for every field.\n",
    "first_example = raw_datasets[\"train\"][0]\n",
    "print(f\"Context: {first_example['context']}\")\n",
    "print(f\"Question: {first_example['question']}\")\n",
    "start = first_example[\"answer_start\"]\n",
    "end = first_example[\"answer_end\"]\n",
    "# Slicing the context with the two character offsets recovers the answer text.\n",
    "print(f\"\\nAnswer: {first_example['context'][start:end]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"bert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "# Tokenize the first training example as a (question, context) pair.\n",
    "example = raw_datasets[\"train\"][0]\n",
    "inputs = tokenizer(\n",
    "    example[\"question\"],\n",
    "    example[\"context\"],\n",
    "    # Length handling.\n",
    "    max_length=384,\n",
    "    stride=128,\n",
    "    truncation=\"only_second\",\n",
    "    padding=\"max_length\",\n",
    "    # Later cells read inputs[\"offset_mapping\"], so the offsets must be requested here.\n",
    "    return_offsets_mapping=True,\n",
    "    return_overflowing_tokens=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_labels(offsets, answer_start, answer_end, sequence_ids):\n",
    "    \"\"\"Return the (start, end) token indices of the answer inside one tokenized feature.\n",
    "\n",
    "    Args:\n",
    "        offsets: per-token (start_char, end_char) pairs.\n",
    "        answer_start: character offset where the answer begins.\n",
    "        answer_end: character offset just past the end of the answer.\n",
    "        sequence_ids: per-token sequence id; entries equal to 1 are treated as context tokens.\n",
    "\n",
    "    Returns:\n",
    "        A (start_position, end_position) tuple of inclusive token indices, or (0, 0) when the\n",
    "        answer is not fully contained in the context tokens of this feature.\n",
    "    \"\"\"\n",
    "    n_tokens = len(sequence_ids)\n",
    "\n",
    "    # Locate the first and last context token; both scans stop at the end of the sequence.\n",
    "    idx = 0\n",
    "    while idx < n_tokens and sequence_ids[idx] != 1:\n",
    "        idx += 1\n",
    "    if idx == n_tokens:\n",
    "        # No context token at all, so the answer cannot be in this feature.\n",
    "        return (0, 0)\n",
    "    context_start = idx\n",
    "    while idx < n_tokens and sequence_ids[idx] == 1:\n",
    "        idx += 1\n",
    "    context_end = idx - 1\n",
    "\n",
    "    # If the answer is not fully in the context, return (0, 0).\n",
    "    # A partially truncated answer would otherwise be labelled with a token outside the context.\n",
    "    if offsets[context_start][0] > answer_start or offsets[context_end][1] < answer_end:\n",
    "        return (0, 0)\n",
    "\n",
    "    # Last context token that starts at or before the answer start.\n",
    "    idx = context_start\n",
    "    while idx <= context_end and offsets[idx][0] <= answer_start:\n",
    "        idx += 1\n",
    "    start_position = idx - 1\n",
    "\n",
    "    # First context token that ends at or after the answer end.\n",
    "    idx = context_end\n",
    "    while idx >= context_start and offsets[idx][1] >= answer_end:\n",
    "        idx -= 1\n",
    "    end_position = idx + 1\n",
    "\n",
    "    return start_position, end_position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the answer inside the first feature produced for `example`.\n",
    "start, end = find_labels(\n",
    "    offsets=inputs[\"offset_mapping\"][0],\n",
    "    answer_start=example[\"answer_start\"],\n",
    "    answer_end=example[\"answer_end\"],\n",
    "    sequence_ids=inputs.sequence_ids(0),\n",
    ")\n",
    "# The end label is used as an inclusive token index, hence the +1 on the slice.\n",
    "tokenizer.decode(inputs[\"input_ids\"][0][start : end + 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_training_examples(examples):\n",
    "    \"\"\"Tokenize a batch of examples and attach the answer token labels.\n",
    "\n",
    "    `examples` is a batch in column form that provides the `question`, `context`,\n",
    "    `answer_start` and `answer_end` columns. Returns the tokenizer output extended with\n",
    "    `start_positions` and `end_positions` lists holding one label per produced feature.\n",
    "    \"\"\"\n",
    "    # Strip surrounding whitespace so it is never tokenized as part of the question.\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        truncation=\"only_second\",\n",
    "        padding=\"max_length\",\n",
    "        max_length=384,\n",
    "        stride=128,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "    )\n",
    "\n",
    "    # Both mappings are only needed to compute the labels, so they are removed from the output.\n",
    "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    inputs[\"start_positions\"] = []\n",
    "    inputs[\"end_positions\"] = []\n",
    "\n",
    "    for i, offset in enumerate(offset_mapping):\n",
    "        # sample_map[i] is used as the index of the example that feature i came from.\n",
    "        sample_idx = sample_map[i]\n",
    "        start, end = find_labels(\n",
    "            offset,\n",
    "            examples[\"answer_start\"][sample_idx],\n",
    "            examples[\"answer_end\"][sample_idx],\n",
    "            inputs.sequence_ids(i),\n",
    "        )\n",
    "\n",
    "        inputs[\"start_positions\"].append(start)\n",
    "        inputs[\"end_positions\"].append(end)\n",
    "\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every original column is removed, leaving only what preprocess_training_examples returns.\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    function=preprocess_training_examples,\n",
    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
    "    batched=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "name": "Data processing for Question Answering",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
